{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train LSTM on TBI signal data\n",
    "\n",
    "Code to train an LSTM on the TBI data to predict hypoxemia in the future (low SAO2).\n",
    "\n",
    "Note that the data is private and we are unable to make it publicly available in this repo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,4\"\n",
    "\n",
    "import numpy as np\n",
    "from tbi_downstream_prediction import split_data\n",
    "\n",
    "import keras\n",
    "from keras.utils import multi_gpu_model\n",
    "from keras.layers import Input, LSTM, Dense, Dropout\n",
    "from keras.models import Sequential, load_model, Model\n",
    "from matplotlib import cm, pyplot as plt\n",
    "from sklearn import metrics\n",
    "from os.path import expanduser as eu\n",
    "from os.path import isfile, join\n",
    "import numpy as np\n",
    "import random, time\n",
    "\n",
    "import tensorflow as tf\n",
    "from keras.backend.tensorflow_backend import set_session\n",
    "config = tf.ConfigProto(allow_soft_placement=True,gpu_options = tf.GPUOptions(allow_growth=True))\n",
    "set_session(tf.Session(config=config))\n",
    "GPUNUM = len(os.environ[\"CUDA_VISIBLE_DEVICES\"].split(\",\"))\n",
    "\n",
    "PATH = \"/homes/gws/hughchen/phase/tbi_subset/\"\n",
    "DPATH = PATH+\"tbi/processed_data/hypoxemia/\"\n",
    "data_type = \"raw[top11]\"\n",
    "\n",
    "feat_lst = [\"ECGRATE\", \"ETCO2\", \"ETSEV\", \"ETSEVO\", \"FIO2\", \"NIBPD\", \"NIBPM\", \n",
    "            \"NIBPS\",\"PEAK\", \"PEEP\", \"PIP\", \"RESPRATE\", \"SAO2\", \"TEMP1\", \"TV\"]\n",
    "\n",
    "# Exclude these features\n",
    "exclude_feat_lst = [\"ETSEV\", \"PIP\", \"PEEP\", \"TV\"]\n",
    "feat_inds = np.array([feat_lst.index(feat) for feat in feat_lst if feat not in exclude_feat_lst])\n",
    "feat_lst2 = [feat for feat in feat_lst if feat not in exclude_feat_lst]\n",
    "\n",
    "y_tbi = np.load(DPATH+\"tbiy.npy\")\n",
    "X_tbi = np.load(DPATH+\"X_tbi_imp_standard.npy\")\n",
    "\n",
    "X_tbi2 = X_tbi[:,feat_inds,:]\n",
    "(X_test, y_test, X_valid, y_valid, X_train, y_train) = split_data(DPATH,X_tbi2,y_tbi,flatten=False)\n",
    "\n",
    "PATH = \"/homes/gws/hughchen/phase/tbi_subset/\"\n",
    "RESULTPATH = PATH+\"results/\"\n",
    "label_type = \"desat_bool92_5_nodesat\"\n",
    "lstm_type = \"biglstmdropoutv3_{}\".format(label_type)\n",
    "RESDIR = '{}{}/'.format(RESULTPATH, lstm_type)\n",
    "if not os.path.exists(RESDIR): os.makedirs(RESDIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Hyperparameter tuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set of hyperparameters to tune over\n",
    "node_lst   = [50, 100, 200]\n",
    "opt_lst    = [\"rmsprop\", \"adam\", \"sgd\"]\n",
    "lr_lst     = [0.01, 0.001, 0.0001]\n",
    "drop_lst   = [0.3, 0.5, 0.7]\n",
    "ISTUNE = True\n",
    "\n",
    "# Greedily find best opt\n",
    "min_loss_lst = [\n",
    "    create_train_model(opt_name=\"rmsprop\",lr=0.001,drop=0.5,\n",
    "                       b_size=1000,nodesize=para)\n",
    "    for para in node_lst\n",
    "]\n",
    "best_nodesize = node_lst[min_loss_lst.index(min(min_loss_lst))]\n",
    "print(\"best best_nodesize: {}\".format(best_nodesize))\n",
    "\n",
    "# Greedily find best opt\n",
    "min_loss_lst = [\n",
    "    create_train_model(opt_name=para,lr=0.001,drop=0.5,\n",
    "                       b_size=1000,nodesize=best_nodesize)\n",
    "    for para in opt_lst\n",
    "]\n",
    "best_optname = opt_lst[min_loss_lst.index(min(min_loss_lst))]\n",
    "print(\"best optname: {}\".format(best_optname))\n",
    "\n",
    "# Greedily find best lr\n",
    "min_loss_lst = [\n",
    "    create_train_model(opt_name=best_optname,lr=para,drop=0.5,\n",
    "                       b_size=1000,nodesize=best_nodesize)\n",
    "    for para in lr_lst\n",
    "]\n",
    "best_lr = lr_lst[min_loss_lst.index(min(min_loss_lst))]\n",
    "print(\"best lr: {}\".format(best_lr))\n",
    "\n",
    "# Greedily find best dropout\n",
    "min_loss_lst = [\n",
    "    create_train_model(opt_name=best_optname,lr=best_lr,drop=para,\n",
    "                       b_size=1000,nodesize=best_nodesize)\n",
    "    for para in drop_lst\n",
    "]\n",
    "best_drop = drop_lst[min_loss_lst.index(min(min_loss_lst))]\n",
    "print(\"best drop: {}\".format(best_drop))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the model with the best hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "create_train_model(opt_name=best_optname,lr=best_lr,drop=best_drop,\n",
    "                   b_size=1000,nodesize=best_nodesize,epoch_num=200,\n",
    "                   is_tune=False)\n",
    "eval_test_model(opt_name=best_optname,lr=best_lr,drop=best_drop,\n",
    "                b_size=1000,nodesize=best_nodesize,epoch_num=200)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
